both on expectation and standard deviation compared to the full-precision baseline and the ternary model. For instance, the top-1 eigenvalues of MHA-O in the binary model are ∼15× larger than those of the full-precision counterpart. Therefore, the quantization loss increases of the full-precision and ternary models are more tightly bounded than that of the binary model in Eq. (5.19). The highly complex and irregular landscape induced by binarization thus poses more challenges to the optimization.
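One hedged way to reproduce such a curvature measurement is power iteration with Hessian-vector products. The sketch below assumes a PyTorch loss that has already been computed on a batch for the parameters of interest (e.g., the MHA-O weights mentioned above); the helper name `top_hessian_eigenvalue` and the overall setup are illustrative assumptions, not the authors' code.

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=20):
    """Estimate the top-1 eigenvalue of the loss Hessian w.r.t. `params`
    via power iteration on Hessian-vector products (illustrative sketch)."""
    # First-order gradients with a retained graph so we can differentiate again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Random unit start vector, one block per parameter tensor.
    v = [torch.randn_like(p) for p in params]
    norm = torch.sqrt(sum((x * x).sum() for x in v))
    v = [x / norm for x in v]
    eig = 0.0
    for _ in range(iters):
        # Hessian-vector product: differentiating g·v w.r.t. the weights gives H v.
        gv = sum((g * x).sum() for g, x in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v (v has unit norm) is the current estimate.
        eig = sum((h * x).sum() for h, x in zip(hv, v)).item()
        norm = torch.sqrt(sum((h * h).sum() for h in hv))
        v = [h / (norm + 1e-12) for h in hv]
    return eig
```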
5.7.1 Ternary Weight Splitting
Given the challenging loss landscape of binary BERT, the authors proposed ternary weight splitting (TWS), which exploits the flatness of the ternary loss landscape as an optimization proxy for the binary model. As shown in Fig. 2.4, they first train a half-sized ternary BERT to convergence, and then split both the latent full-precision weight $\mathbf{W}^t$ and the quantized $\hat{\mathbf{W}}^t$ into their binary counterparts $\mathbf{W}^b_1, \mathbf{W}^b_2$ and $\hat{\mathbf{W}}^b_1, \hat{\mathbf{W}}^b_2$ via the TWS operator. To inherit the performance of the ternary model after splitting, the TWS operator requires the splitting equivalency (i.e., the same output given the same input):
$$
\mathbf{W}^t = \mathbf{W}^b_1 + \mathbf{W}^b_2, \qquad \hat{\mathbf{W}}^t = \hat{\mathbf{W}}^b_1 + \hat{\mathbf{W}}^b_2. \tag{5.20}
$$
While the solution to Eq. (5.20) is not unique, the latent full-precision weights $\mathbf{W}^b_1, \mathbf{W}^b_2$ are constrained after splitting to satisfy $\mathbf{W}^t = \mathbf{W}^b_1 + \mathbf{W}^b_2$ as
$$
\mathbf{W}^b_{1,i} =
\begin{cases}
a \cdot \mathbf{W}^t_i, & \text{if } \hat{\mathbf{W}}^t_i \neq 0 \\
b + \mathbf{W}^t_i, & \text{if } \hat{\mathbf{W}}^t_i = 0,\ \mathbf{W}^t_i > 0 \\
b, & \text{otherwise}
\end{cases}
\tag{5.21}
$$
$$
\mathbf{W}^b_{2,i} =
\begin{cases}
(1-a) \cdot \mathbf{W}^t_i, & \text{if } \hat{\mathbf{W}}^t_i \neq 0 \\
-b, & \text{if } \hat{\mathbf{W}}^t_i = 0,\ \mathbf{W}^t_i > 0 \\
-b + \mathbf{W}^t_i, & \text{otherwise}
\end{cases}
\tag{5.22}
$$
where $a$ and $b$ are the variables to solve. By Eq. (5.21) and Eq. (5.22) with $\hat{\mathbf{W}}^t = \hat{\mathbf{W}}^b_1 + \hat{\mathbf{W}}^b_2$, we get
$$
a = \frac{\sum_{i \in \mathcal{I}} |\mathbf{W}^t_i| + \sum_{j \in \mathcal{J}} |\mathbf{W}^t_j| - \sum_{k \in \mathcal{K}} |\mathbf{W}^t_k|}{2\sum_{i \in \mathcal{I}} |\mathbf{W}^t_i|}, \qquad
b = \frac{\frac{n}{|\mathcal{I}|}\sum_{i \in \mathcal{I}} |\mathbf{W}^t_i| - \sum_{i=1}^{n} |\mathbf{W}^t_i|}{2(|\mathcal{J}| + |\mathcal{K}|)},
\tag{5.23}
$$
where we denote $\mathcal{I} = \{i \mid \hat{\mathbf{W}}^t_i \neq 0\}$, $\mathcal{J} = \{j \mid \hat{\mathbf{W}}^t_j = 0 \text{ and } \mathbf{W}^t_j > 0\}$, and $\mathcal{K} = \{k \mid \hat{\mathbf{W}}^t_k = 0 \text{ and } \mathbf{W}^t_k < 0\}$, and $|\cdot|$ denotes the cardinality of a set.
5.7.2 Knowledge Distillation
Further, the authors proposed to boost the performance of binarized BERT by Knowledge Distillation (KD), which is shown to benefit BERT quantization [285]. Following [106, 285], they first performed intermediate-layer distillation from the full-precision teacher network's embedding $\mathbf{E}$, layer-wise MHA output $\mathbf{M}_l$, and FFN output $\mathbf{F}_l$ to the quantized student counterparts $\hat{\mathbf{E}}, \hat{\mathbf{M}}_l, \hat{\mathbf{F}}_l$ ($l = 1, 2, \ldots, L$). Minimizing their mean squared errors, i.e., $\ell_{\mathrm{emb}} = \mathrm{MSE}(\hat{\mathbf{E}}, \mathbf{E})$, $\ell_{\mathrm{mha}} = \sum_l \mathrm{MSE}(\hat{\mathbf{M}}_l, \mathbf{M}_l)$, and $\ell_{\mathrm{ffn}} = \sum_l \mathrm{MSE}(\hat{\mathbf{F}}_l, \mathbf{F}_l)$, the objective is
$$
\ell_{\mathrm{int}} = \ell_{\mathrm{emb}} + \ell_{\mathrm{mha}} + \ell_{\mathrm{ffn}}. \tag{5.24}
$$
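As a concrete illustration of Eq. (5.24), the following PyTorch sketch computes the intermediate-layer distillation objective from cached teacher and student activations. The dictionary layout and the helper name `intermediate_distillation_loss` are assumptions made for illustration, not the authors' training code.

```python
import torch.nn.functional as F

def intermediate_distillation_loss(student, teacher):
    """Intermediate-layer distillation loss of Eq. (5.24).

    `student` and `teacher` are dicts holding:
      'emb': embedding output tensor,
      'mha': list of per-layer MHA output tensors (length L),
      'ffn': list of per-layer FFN output tensors (length L).
    """
    loss_emb = F.mse_loss(student['emb'], teacher['emb'])
    loss_mha = sum(F.mse_loss(s, t) for s, t in zip(student['mha'], teacher['mha']))
    loss_ffn = sum(F.mse_loss(s, t) for s, t in zip(student['ffn'], teacher['ffn']))
    return loss_emb + loss_mha + loss_ffn
```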